{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 生成验证码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from captcha.image import ImageCaptcha\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import random\n",
    "\n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "\n",
    "import string\n",
    "characters = string.digits + string.ascii_uppercase\n",
    "print(characters)\n",
    "\n",
    "width, height, n_len, n_class = 170, 80, 4, len(characters)+1\n",
    "\n",
    "generator = ImageCaptcha(width=width, height=height)\n",
    "random_str = ''.join([random.choice(characters) for j in range(4)])\n",
    "img = generator.generate_image(random_str)\n",
    "\n",
    "plt.imshow(img)\n",
    "plt.title(random_str)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 定义 CTC Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras import backend as K\n",
    "\n",
    "def ctc_lambda_func(args):\n",
    "    y_pred, labels, input_length, label_length = args\n",
    "    y_pred = y_pred[:, 2:, :]\n",
    "    return K.ctc_batch_cost(labels, y_pred, input_length, label_length)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 构建模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.models import *\n",
    "from keras.layers import *\n",
    "from keras.optimizers import *\n",
    "rnn_size = 128\n",
    "\n",
    "input_tensor = Input((width, height, 3))\n",
    "x = input_tensor\n",
    "x = Lambda(lambda x:(x-127.5)/127.5)(x)\n",
    "for i in range(3):\n",
    "    for j in range(2):\n",
    "        x = Conv2D(32*2**i, 3, kernel_initializer='he_uniform')(x)\n",
    "        x = BatchNormalization()(x)\n",
    "        x = Activation('relu')(x)\n",
    "    x = MaxPooling2D((2, 2))(x)\n",
    "\n",
    "conv_shape = x.get_shape().as_list()\n",
    "rnn_length = conv_shape[1]\n",
    "rnn_dimen = conv_shape[2]*conv_shape[3]\n",
    "print(conv_shape, rnn_length, rnn_dimen)\n",
    "x = Reshape(target_shape=(rnn_length, rnn_dimen))(x)\n",
    "rnn_length -= 2\n",
    "\n",
    "x = Dense(rnn_size, kernel_initializer='he_uniform')(x)\n",
    "x = BatchNormalization()(x)\n",
    "x = Activation('relu')(x)\n",
    "x = Dropout(0.2)(x)\n",
    "\n",
    "gru_1 = GRU(rnn_size, return_sequences=True, kernel_initializer='he_uniform', name='gru1')(x)\n",
    "gru_1b = GRU(rnn_size, return_sequences=True, kernel_initializer='he_uniform', \n",
    "             go_backwards=True, name='gru1_b')(x)\n",
    "x = add([gru_1, gru_1b])\n",
    "\n",
    "gru_2 = GRU(rnn_size, return_sequences=True, kernel_initializer='he_uniform', name='gru2')(x)\n",
    "gru_2b = GRU(rnn_size, return_sequences=True, kernel_initializer='he_uniform', \n",
    "             go_backwards=True, name='gru2_b')(x)\n",
    "x = concatenate([gru_2, gru_2b])\n",
    "\n",
    "x = Dropout(0.2)(x)\n",
    "x = Dense(n_class, activation='softmax')(x)\n",
    "base_model = Model(inputs=input_tensor, outputs=x)\n",
    "\n",
    "labels = Input(name='the_labels', shape=[n_len], dtype='float32')\n",
    "input_length = Input(name='input_length', shape=[1], dtype='int64')\n",
    "label_length = Input(name='label_length', shape=[1], dtype='int64')\n",
    "loss_out = Lambda(ctc_lambda_func, output_shape=(1,), \n",
    "                  name='ctc')([x, labels, input_length, label_length])\n",
    "\n",
    "model = Model(inputs=[input_tensor, labels, input_length, label_length], outputs=[loss_out])\n",
    "model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adam')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from keras.utils.vis_utils import plot_model\n",
    "from IPython.display import Image\n",
    "\n",
    "plot_model(model, to_file=\"model.png\", show_shapes=True)\n",
    "Image('model.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def gen(batch_size=128):\n",
    "    X = np.zeros((batch_size, width, height, 3), dtype=np.uint8)\n",
    "    y = np.zeros((batch_size, n_len), dtype=np.uint8)\n",
    "    generator = ImageCaptcha(width=width, height=height)\n",
    "    while True:\n",
    "        for i in range(batch_size):\n",
    "            random_str = ''.join([random.choice(characters) for j in range(n_len)])\n",
    "            X[i] = np.array(generator.generate_image(random_str)).transpose(1, 0, 2)\n",
    "            y[i] = [characters.find(x) for x in random_str]\n",
    "        yield [X, y, np.ones(batch_size)*rnn_length, np.ones(batch_size)*n_len], np.ones(batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(X_vis, y_vis, input_length_vis, label_length_vis), _ = next(gen(1))\n",
    "print(X_vis.shape, y_vis, input_length_vis, label_length_vis)\n",
    "\n",
    "plt.imshow(X_vis[0].transpose(1, 0, 2))\n",
    "plt.title(''.join([characters[i] for i in y_vis[0]]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def evaluate(batch_size=128, steps=10):\n",
    "    batch_acc = 0\n",
    "    generator = gen(batch_size)\n",
    "    for i in range(steps):\n",
    "        [X_test, y_test, _, _], _  = next(generator)\n",
    "        y_pred = base_model.predict(X_test)\n",
    "        shape = y_pred[:,2:,:].shape\n",
    "        ctc_decode = K.ctc_decode(y_pred[:,2:,:], input_length=np.ones(shape[0])*shape[1])[0][0]\n",
    "        out = K.get_value(ctc_decode)[:, :n_len]\n",
    "        if out.shape[1] == n_len:\n",
    "            batch_acc += (y_test == out).all(axis=1).mean()\n",
    "    return batch_acc / steps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from keras.callbacks import *\n",
    "\n",
    "class Evaluator(Callback):\n",
    "    def __init__(self):\n",
    "        self.accs = []\n",
    "    \n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        acc = evaluate(steps=20)*100\n",
    "        self.accs.append(acc)\n",
    "        print('')\n",
    "        print('acc: %f%%' % acc)\n",
    "\n",
    "evaluator = Evaluator()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "h = model.fit_generator(gen(128), steps_per_epoch=400, epochs=20,\n",
    "                        callbacks=[evaluator],\n",
    "                        validation_data=gen(128), validation_steps=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=Adam(1e-4))\n",
    "h2 = model.fit_generator(gen(128), steps_per_epoch=400, epochs=20, \n",
    "                        callbacks=[evaluator],\n",
    "                        validation_data=gen(128), validation_steps=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 4))\n",
    "plt.subplot(1, 2, 1)\n",
    "plt.plot(h.history['loss'] + h2.history['loss'])\n",
    "plt.plot(h.history['val_loss'] + h2.history['val_loss'])\n",
    "plt.legend(['loss', 'val_loss'])\n",
    "plt.ylabel('loss')\n",
    "plt.xlabel('epoch')\n",
    "plt.ylim(0, 1)\n",
    "\n",
    "plt.subplot(1, 2, 2)\n",
    "plt.plot(evaluator.accs)\n",
    "plt.ylabel('acc')\n",
    "plt.xlabel('epoch')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(X_vis, y_vis, input_length_vis, label_length_vis), _ = next(gen(12))\n",
    "\n",
    "y_pred = base_model.predict(X_vis)\n",
    "shape = y_pred[:,2:,:].shape\n",
    "ctc_decode = K.ctc_decode(y_pred[:,2:,:], input_length=np.ones(shape[0])*shape[1])[0][0]\n",
    "out = K.get_value(ctc_decode)[:, :4]\n",
    "\n",
    "plt.figure(figsize=(16, 8))\n",
    "for i in range(12):\n",
    "    plt.subplot(3, 4, i+1)\n",
    "    plt.imshow(X_vis[i].transpose(1, 0, 2))\n",
    "    plt.title('pred:%s\\nreal :%s' % (''.join([characters[x] for x in out[i]]), \n",
    "                                     ''.join([characters[x] for x in y_vis[i]])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(X_vis, y_vis, input_length_vis, label_length_vis), _ = next(gen(10000))\n",
    "\n",
    "y_pred = base_model.predict(X_vis, verbose=1)\n",
    "shape = y_pred[:,2:,:].shape\n",
    "ctc_decode = K.ctc_decode(y_pred[:,2:,:], input_length=np.ones(shape[0])*shape[1])[0][0]\n",
    "out = K.get_value(ctc_decode)[:, :4]\n",
    "\n",
    "(y_vis == out).all(axis=1).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "Counter(''.join([characters[i] for i in y_vis[y_vis != out]]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "base_model.save('model.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "characters2 = characters + ' '\n",
    "\n",
    "generator = ImageCaptcha(width=width, height=height)\n",
    "random_str = '0O0O'\n",
    "X_test = np.array(generator.generate_image(random_str))\n",
    "X_test = X_test.transpose(1, 0, 2)\n",
    "X_test = np.expand_dims(X_test, 0)\n",
    "\n",
    "y_pred = base_model.predict(X_test)\n",
    "shape = y_pred[:,2:,:].shape\n",
    "ctc_decode = K.ctc_decode(y_pred[:,2:,:], input_length=np.ones(shape[0])*shape[1])[0][0]\n",
    "out = K.get_value(ctc_decode)[:, :4]\n",
    "out = ''.join([characters[x] for x in out[0]])\n",
    "\n",
    "plt.imshow(X_test[0].transpose(1, 0, 2))\n",
    "plt.title('pred:' + str(out))\n",
    "\n",
    "argmax = np.argmax(y_pred, axis=2)[0]\n",
    "list(zip(argmax, ''.join([characters2[x] for x in argmax])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
